import numpy as np
import torch
from torch.distributions import Binomial
try:
    import gtn
except:
    print("fail to import gtn, cannot use NFA_GTN")




def create_semisup_graph(log_probs, targets=None,
                         drop_path=False, min_num_paths=1, threshold=0.95):
    """Build the label graph used for the semi-supervised setting.

    The graph is a chain: one start node plus one node per row of
    ``log_probs``. Between node ``i`` and node ``i + 1`` there is one arc per
    class label that is allowed for row ``i``.

    Without ``drop_path`` every class is allowed for every row. With
    ``drop_path``, rows whose largest probability is below ``threshold`` keep
    only a random subset of the classes: each class survives a 0/1 draw with
    success probability ``exp(log_probs)``, and the ``min_num_paths`` most
    probable classes are always kept. Rows at or above ``threshold`` keep all
    classes.

    Args:
        log_probs: 2-D tensor of shape ``(b, k)``. ``exp(log_probs)`` is used
            as a probability, so entries are expected to be log-probabilities.
        targets: Unused here; accepted so the function can be called with the
            same positional arguments as the other label-graph builders.
        drop_path: Whether to randomly drop arcs as described above.
        min_num_paths: Number of top-probability classes that always keep
            their arc when arcs are dropped. Clamped to ``k``.
        threshold: Rows whose top probability is below this value get arcs
            dropped.

    Returns:
        ``(g_labels, avg_paths)``: the arc-sorted ``gtn.Graph`` and the mean
        number of arcs per row.
    """
    b, k = log_probs.shape
    # NOTE(review): the positional False is presumably gtn's ``calc_grad``
    # flag (no gradient needed for the label graph) -- confirm against gtn.
    g_labels = gtn.Graph(False)
    g_labels.add_node(start=True)

    if drop_path:
        probs = torch.exp(log_probs)
        # One trial per entry, i.e. a 0/1 keep mask sampled per class.
        binomial_mask = Binomial(torch.ones_like(log_probs), probs).sample()
        # torch.topk raises if asked for more entries than there are classes.
        num_keep = min(min_num_paths, k)
        topk_values, topk_indices = torch.topk(probs, num_keep, dim=1)

    avg_paths = 0
    for i in range(b):
        # Only the node after the last row is accepting.
        g_labels.add_node(accept=(i == b - 1))

        if drop_path and topk_values[i][0] < threshold:
            mask = binomial_mask[i]
            # Force the most probable classes to survive the random drop.
            mask[topk_indices[i]] = 1
            # Plain Python ints: one transfer per row instead of one per arc.
            labels = torch.nonzero(mask, as_tuple=True)[0].tolist()
        else:
            labels = range(k)
        avg_paths += len(labels)

        for label in labels:
            g_labels.add_arc(i, i + 1, label)
    avg_paths /= b
    g_labels.arc_sort(False)
    return g_labels, avg_paths


def create_partial_label_graph(log_probs, targets,
                               drop_path=False, min_num_paths=5):
    """Build the label graph used for the partial-label setting.

    The graph is a chain: one start node plus one node per row of
    ``log_probs``. Between node ``i`` and node ``i + 1`` there is one arc per
    class label that is allowed for row ``i``.

    Without ``drop_path`` the allowed labels of row ``i`` are the nonzero
    positions of ``targets[i]``. With ``drop_path`` the allowed labels are the
    union of (a) classes that survive a 0/1 draw with success probability
    ``exp(log_probs)``, (b) the top-probability classes, and (c) the nonzero
    positions of ``targets[i]``.

    Args:
        log_probs: 2-D tensor of shape ``(b, k)``. ``exp(log_probs)`` is used
            as a probability, so entries are expected to be log-probabilities.
        targets: Tensor indexed per row whose nonzero entries mark candidate
            labels (assumed to be a ``(b, k)`` multi-hot tensor -- TODO
            confirm against callers).
        drop_path: Whether to use the random mask described above.
        min_num_paths: Upper bound on the number of top-probability classes
            that are always kept; the value actually used is the minimum of
            this, the truncated mean of ``targets.sum(dim=-1)``, and ``k``.

    Returns:
        ``(g_labels, avg_paths)``: the arc-sorted ``gtn.Graph`` and the mean
        number of arcs per row.
    """
    b, k = log_probs.shape
    # NOTE(review): the positional False is presumably gtn's ``calc_grad``
    # flag (no gradient needed for the label graph) -- confirm against gtn.
    g_labels = gtn.Graph(False)
    g_labels.add_node(start=True)

    if drop_path:
        probs = torch.exp(log_probs)
        # One trial per entry, i.e. a 0/1 keep mask sampled per class.
        binomial_mask = Binomial(torch.ones_like(log_probs), probs).sample()
        mean_candidates = int(targets.sum(dim=-1).mean().item())
        # Also bound by k: torch.topk raises if asked for more entries than
        # there are classes.
        num_keep = min(min_num_paths, mean_candidates, k)
        _, topk_indices = torch.topk(probs, num_keep, dim=1)

    avg_paths = 0
    for i in range(b):
        # Only the node after the last row is accepting.
        g_labels.add_node(accept=(i == b - 1))

        if drop_path:
            mask = binomial_mask[i]
            mask[topk_indices[i]] = 1
            # TODO: fix this bug
            # NOTE(review): the union below also keeps sampled / top-k classes
            # that are not candidates in ``targets[i]``; it is unclear whether
            # that is the bug the TODO refers to, so the logic is unchanged.
            mask = torch.logical_or(mask, targets[i])
            candidates = torch.nonzero(mask, as_tuple=True)[0]
        else:
            candidates = torch.nonzero(targets[i], as_tuple=True)[0]
        # Plain Python ints: one transfer per row instead of one per arc.
        labels = candidates.tolist()
        avg_paths += len(labels)

        for label in labels:
            g_labels.add_arc(i, i + 1, label)
    avg_paths /= b
    g_labels.arc_sort(False)
    return g_labels, avg_paths


def create_noisy_label_graph(log_probs, targets=None, drop_path=False, min_num_paths=5):
    """Build the label graph used for the noisy-label setting.

    The graph is a chain: one start node plus one node per row of
    ``log_probs``. Between node ``i`` and node ``i + 1`` there is one arc per
    class label that is allowed for row ``i``.

    Without ``drop_path`` every class is allowed for every row. With
    ``drop_path`` each class survives a 0/1 draw with success probability
    ``exp(log_probs)``, and the ``min_num_paths`` highest-scoring classes of
    the row are always kept.

    Args:
        log_probs: 2-D tensor of shape ``(b, k)``. ``exp(log_probs)`` is used
            as a probability, so entries are expected to be log-probabilities.
        targets: Unused here; accepted so the function can be called with the
            same positional arguments as the other label-graph builders.
        drop_path: Whether to randomly drop arcs as described above.
        min_num_paths: Number of top-scoring classes that always keep their
            arc when arcs are dropped. Clamped to ``k``.

    Returns:
        ``(g_labels, avg_paths)``: the arc-sorted ``gtn.Graph`` and the mean
        number of arcs per row.
    """
    b, k = log_probs.shape
    # NOTE(review): the positional False is presumably gtn's ``calc_grad``
    # flag (no gradient needed for the label graph) -- confirm against gtn.
    g_labels = gtn.Graph(False)
    g_labels.add_node(start=True)

    if drop_path:
        # One trial per entry, i.e. a 0/1 keep mask sampled per class.
        binomial_mask = Binomial(
            torch.ones_like(log_probs), torch.exp(log_probs)
        ).sample()
        # torch.topk raises if asked for more entries than there are classes.
        num_keep = min(min_num_paths, k)
        _, topk_indices = torch.topk(log_probs, num_keep, dim=1)

    avg_paths = 0
    for i in range(b):
        # Only the node after the last row is accepting.
        g_labels.add_node(accept=(i == b - 1))

        if drop_path:
            mask = binomial_mask[i]
            # Force the highest-scoring classes to survive the random drop.
            mask[topk_indices[i]] = 1
            # Plain Python ints: one transfer per row instead of one per arc.
            labels = torch.nonzero(mask, as_tuple=True)[0].tolist()
        else:
            labels = range(k)
        avg_paths += len(labels)

        for label in labels:
            g_labels.add_arc(i, i + 1, label)
    avg_paths /= b
    g_labels.arc_sort(False)
    return g_labels, avg_paths


def create_weak_label_graph():
    """Placeholder for the weak-label graph builder; not implemented yet.

    Takes no arguments, performs no work and returns ``None``.
    """
    return None



class NFA_GTN:
    """Compute per-sample soft targets by intersecting emissions with a label graph.

    The label graph is produced by one of the module-level builders, chosen
    by ``label_config``. ``compute`` intersects it with a linear emission
    graph and returns the row-normalized gradient of the forward score with
    respect to the emission weights.
    """

    def __init__(self, label_config, drop_path=False, min_num_paths=5):
        """Select the label-graph builder and store the drop-path settings.

        Args:
            label_config: One of ``'semisup'``, ``'partial_label'`` or
                ``'noisy_label'``.
            drop_path: Forwarded to the label-graph builder.
            min_num_paths: Forwarded to the label-graph builder.

        Raises:
            NotImplementedError: For ``'weak_label'`` and for any
                unrecognized ``label_config``.
        """
        self.label_config = label_config
        # Misspelled name kept as an alias so existing readers keep working.
        self.lable_config = label_config
        if label_config == 'semisup':
            self.create_label_graph = create_semisup_graph
        elif label_config == 'partial_label':
            self.create_label_graph = create_partial_label_graph
        elif label_config == 'noisy_label':
            self.create_label_graph = create_noisy_label_graph
        elif label_config == 'weak_label':
            raise NotImplementedError(
                "label_config 'weak_label' is not implemented yet")
        else:
            raise NotImplementedError(
                f"unsupported label_config: {label_config!r}")
        self.drop_path = drop_path
        self.min_num_path = min_num_paths

    def create_emission_graph(self, log_probs):
        """Return a ``gtn`` linear graph over ``(b, k)`` weighted by ``log_probs``."""
        b, k = log_probs.shape
        g_emissions = gtn.linear_graph(b, k)
        # detach(): Tensor.numpy() raises on tensors that require grad.
        g_emissions.set_weights(log_probs.detach().cpu().numpy())
        return g_emissions

    def compute(self, log_probs, targets=None):
        """Compute soft targets for one batch.

        Args:
            log_probs: 2-D tensor of shape ``(b, k)``.
            targets: Forwarded unchanged to the label-graph builder; the
                ``'partial_label'`` builder requires it.

        Returns:
            ``(em_targets, avg_paths)``: a ``(b, k)`` tensor on
            ``log_probs.device`` whose rows are clamped to ``[0, 1]`` and then
            divided by their sum, and the mean number of label arcs per row
            as reported by the label-graph builder.
        """
        b, k = log_probs.shape

        g_emissions = self.create_emission_graph(log_probs)

        # All builders are called with the same four positional arguments.
        g_labels, avg_paths = self.create_label_graph(
            log_probs, targets, self.drop_path, self.min_num_path)

        # Restrict the emission graph to the paths the label graph allows.
        g_nfa = gtn.intersect(g_emissions, g_labels)

        loss = gtn.forward_score(g_nfa)
        # NOTE(review): the positional False is presumably gtn's
        # ``retain_graph`` flag -- confirm against the gtn API.
        gtn.backward(loss, False)

        # Gradient of the forward score w.r.t. each emission weight.
        em_targets = g_emissions.grad().weights_to_numpy()
        em_targets = torch.from_numpy(em_targets).view(b, k).to(log_probs.device)
        em_targets = torch.clamp(em_targets, min=0, max=1.0)
        # Renormalize each row to sum to one after clamping.
        em_targets = em_targets / em_targets.sum(dim=1, keepdims=True)
        return em_targets, avg_paths


